# -*- coding: utf-8 -*-
# @Time    : 2021/6/3 17:09
# @Author  : 万方名
# @FileName: SVM_classfication.py

from sklearn import svm

# Binary classification: two training points, one per class.
X = [[0, 0], [1, 1]]
y = [0, 1]
clf = svm.SVC().fit(X, y)  # fit() returns the fitted estimator itself

# Classify an unseen point.
pre = clf.predict([[2.0, 2.0]])
print(pre)

# Inspect the fitted model, printed in this order:
#   support_vectors_ -> the support vectors themselves
#                       (here both training points: [[0. 0.] [1. 1.]])
#   support_         -> their row indices into X (here [0 1])
#   n_support_       -> number of support vectors contributed by each class
for fitted_attr in ("support_vectors_", "support_", "n_support_"):
    print(getattr(clf, fitted_attr))


# Multi-class classification: four 1-D points, each in its own class.
X = [[0], [1], [2], [3]]
Y = [0, 1, 2, 3]
clf = svm.SVC(decision_function_shape='ovo')
clf.fit(X, Y)

# One-vs-one: one decision value per pair of classes,
# i.e. n_classes * (n_classes - 1) / 2 = 4 * 3 / 2 = 6 columns.
dec = clf.decision_function([[1]])
print(dec.shape[1])  # 6

# One-vs-rest: one decision value per class -> 4 columns.
# decision_function_shape only changes how decision_function() aggregates
# its output, so it can be switched on the already-fitted model.
clf.set_params(decision_function_shape="ovr")
dec = clf.decision_function([[1]])
print(dec.shape[1])  # 4


# Regression with an SVM (SVR) on two training points.
X = [[0, 0], [2, 2]]
y = [0.5, 2.5]
regr = svm.SVR()
regr.fit(X, y)

# Predict at the point midway between the two training samples. Print the
# result; a bare expression in a script would silently discard it.
print(regr.predict([[1, 1]]))  # expected around the midpoint target, [1.5]


# Trying different kernel functions: the kernel is selected with the
# `kernel` constructor argument and can be read back from the attribute.
linear_svc = svm.SVC(kernel='linear')
print(linear_svc.kernel)  # linear

rbf_svc = svm.SVC(kernel='rbf')
print(rbf_svc.kernel)  # rbf